#!/usr/bin/env python3
"""
Phase 4h: CCA Capital Account Opening Event Study

Tests whether the demographic-CA relationship "turns on" when CCA countries
open their capital accounts. This is a natural experiment:
- Treatment: capital account liberalization (KAOPEN increase)
- Pre-existing condition: demographic structure (predetermined by births decades ago)
- Outcome: does the demographic-CA link activate post-opening?

Approaches:
1. Descriptive: KAOPEN trajectories for each CCA country, identify opening episodes
2. Event study: Z × post_opening interaction within CCA subsample
3. Split sample: separate pre-opening and post-opening regressions for opened CCA countries
4. Broader test: Z × post_opening event study on all sample countries with an opening episode
5. Validation: the same event study on non-CCA countries that also had opening episodes
6. Triple interaction: Z × KAOPEN × post_opening on the opener subsample
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats

# Absolute location of the follow-up analysis tree; PROJECT_DIR is its parent.
FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
PROJECT_DIR = FOLLOWUP_DIR.parent
# Put both directories at the front of the import search path, follow-up tree
# first, so it takes precedence over the parent project.
# NOTE(review): presumably this is what makes the `src` package importable — confirm.
sys.path.insert(0, str(FOLLOWUP_DIR))
sys.path.insert(1, str(PROJECT_DIR))

from src.model import PanelGLS
from src.macro import EBA_COUNTRIES, SSA_COUNTRIES, filter_eba_sample

# Input (processed panel) and output (result tables) locations.
PROCESSED_DIR = FOLLOWUP_DIR / "data" / "processed"
OUTPUT_DIR = FOLLOWUP_DIR / "output" / "tables"

# Estimation sample: rows with an observed current account in 1986-2024
# (both ends inclusive), then passed through the project's sample filter.
panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
has_ca = panel['ca_gdp'].notna()
in_window = panel['year'].between(1986, 2024)
est = panel.loc[has_ca & in_window].copy()
full_sample = filter_eba_sample(est, extended=True, expansion=True)

# Demographic regressors and the control set used by the wider-sample models.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp',
    'kaopen',
    'expected_growth',
    'nfa_gdp_lag',
    'log_rel_opw',
    'health_exp_gdp',
]


def run_gls(df, dep_var, indep_vars):
    """Fit PanelGLS of ``dep_var`` on ``indep_vars`` using complete cases only.

    Rows with a missing value in the dependent variable or in any regressor
    are dropped before fitting. The model is given the ``iso3`` and ``year``
    columns of the remaining rows alongside ``y`` and ``X``.

    Returns ``None`` when fewer than 3 distinct ``iso3`` values or fewer than
    20 rows remain. Otherwise returns a dict with ``r_squared``, ``n_obs``,
    ``n_countries`` and three dicts (``coefficients``, ``std_errors``,
    ``p_values``) built by pairing ``indep_vars``, in order, with the fitted
    model's ``beta``, ``se`` and ``pvalues``.
    """
    required_cols = [dep_var] + indep_vars
    complete = df.dropna(subset=required_cols).copy()

    enough_countries = complete['iso3'].nunique() >= 3
    enough_rows = len(complete) >= 20
    if not (enough_countries and enough_rows):
        return None

    model = PanelGLS()
    model.fit(
        complete[dep_var].values,
        complete[indep_vars].values,
        complete['iso3'].values,
        complete['year'].values,
    )

    def by_name(values):
        # Pair each regressor name with its estimate, in regressor order.
        return dict(zip(indep_vars, values))

    return {
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'coefficients': by_name(model.beta),
        'std_errors': by_name(model.se),
        'p_values': by_name(model.pvalues),
    }


# ═══════════════════════════════════════════════════════════════════════════
# 1. DESCRIPTIVE: KAOPEN TRAJECTORIES FOR CCA COUNTRIES
# ═══════════════════════════════════════════════════════════════════════════
print("=" * 70)
print("CCA KAOPEN TRAJECTORIES")
print("=" * 70)

cca_countries = ['ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA', 'MNG', 'RUS',
                 'TJK', 'TKM', 'UKR', 'UZB']

cca_data = full_sample[full_sample['iso3'].isin(cca_countries)].copy()

print(f"\nCCA countries in sample: {sorted(cca_data['iso3'].unique())}")
print(f"Year range: {cca_data['year'].min()}-{cca_data['year'].max()}")

# Show KAOPEN trajectory for each country
print(f"\n--- KAOPEN by Country and Decade ---")
print(f"  {'Country':<6s} {'1990s':>8s} {'2000s':>8s} {'2010s':>8s} {'2020s':>8s} {'Latest':>8s} {'Δ(first→last)':>14s}")
print("  " + "-" * 60)

# iso3 -> opening year. A country "opens" at the earlier of:
#   (a) the first year KAOPEN is > 0, if its first observed value is negative;
#   (b) the first year KAOPEN rises by more than 1 point over the previous
#       observed value.
opening_years = {}

for iso in sorted(cca_data['iso3'].unique()):
    sub = cca_data[cca_data['iso3'] == iso].sort_values('year')
    # Keep only years with an observed KAOPEN value. Without this, a missing
    # value in the first or last sample year makes `first`/`latest` NaN, which
    # silently disables rule (a), and rule (b) compares against NaN and misses
    # jumps next to a gap.
    kaopen_ts = sub.set_index('year')['kaopen'].dropna()
    years = kaopen_ts.index

    avg_90s = kaopen_ts[(years >= 1990) & (years < 2000)].mean()
    avg_00s = kaopen_ts[(years >= 2000) & (years < 2010)].mean()
    avg_10s = kaopen_ts[(years >= 2010) & (years < 2020)].mean()
    avg_20s = kaopen_ts[years >= 2020].mean()
    # A country with no KAOPEN observations at all yields an empty series.
    latest = kaopen_ts.iloc[-1] if len(kaopen_ts) > 0 else np.nan
    first = kaopen_ts.iloc[0] if len(kaopen_ts) > 0 else np.nan
    delta = latest - first if not (np.isnan(latest) or np.isnan(first)) else np.nan

    print(f"  {iso:<6s} {avg_90s:8.2f} {avg_00s:8.2f} {avg_10s:8.2f} {avg_20s:8.2f} {latest:8.2f} {delta:14.2f}")

    # Rule (a): starts negative, later crosses above zero.
    if first < 0:
        above_zero = kaopen_ts[kaopen_ts > 0]
        if len(above_zero) > 0:
            opening_years[iso] = above_zero.index[0]

    # Rule (b): earliest jump of more than 1 point between consecutive
    # observations; it overrides rule (a) only when it happens earlier.
    jumps = kaopen_ts.diff() > 1.0
    if jumps.any():
        first_jump_year = jumps.idxmax()  # first True label (series is year-sorted)
        if iso not in opening_years or first_jump_year < opening_years[iso]:
            opening_years[iso] = first_jump_year

print(f"\n--- Identified Opening Episodes ---")
for iso, yr in sorted(opening_years.items(), key=lambda x: x[1]):
    sub = cca_data[cca_data['iso3'] == iso].sort_values('year')
    kaopen_ts = sub.set_index('year')['kaopen'].dropna()
    before = kaopen_ts[kaopen_ts.index < yr]
    after = kaopen_ts[kaopen_ts.index >= yr]
    pre = before.mean() if len(before) > 0 else np.nan
    post = after.mean() if len(after) > 0 else np.nan
    print(f"  {iso}: opened ~{yr} (KAOPEN before: {pre:.2f}, after: {post:.2f})")

# Countries for which neither rule identified an opening year.
# NOTE(review): the label below says "stayed ≤ 0", but a country that starts
# with KAOPEN > 0 and never jumps also lands in this set.
never_opened = set(cca_data['iso3'].unique()) - set(opening_years.keys())
print(f"\n  Never opened (KAOPEN stayed ≤ 0): {sorted(never_opened)}")


# ═══════════════════════════════════════════════════════════════════════════
# 2. EVENT STUDY: Z × post_opening WITHIN CCA
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("EVENT STUDY: Z × POST_OPENING WITHIN CCA")
print("=" * 70)

# post_opening = 1 from a country's identified opening year onward. Countries
# without an identified episode (never_opened) simply keep the 0.0 default.
cca_data['post_opening'] = 0.0
for iso, yr in opening_years.items():
    mask = (cca_data['iso3'] == iso) & (cca_data['year'] >= yr)
    cca_data.loc[mask, 'post_opening'] = 1.0

n_pre = (cca_data['post_opening'] == 0).sum()
n_post = (cca_data['post_opening'] == 1).sum()
print(f"\n  Pre-opening observations: {n_pre}")
print(f"  Post-opening observations: {n_post}")
print(f"  Countries with opening: {len(opening_years)}")
print(f"  Countries never opened: {len(never_opened)}")

# Interaction terms: each demographic variable times the post-opening dummy.
post_interactions = [f'{z}_x_post' for z in demo_vars]
for z, name in zip(demo_vars, post_interactions):
    cca_data[name] = cca_data[z] * cca_data['post_opening']

# Model A: baseline on CCA only. A reduced control set is used because the
# CCA subsample is small.
base_vars_cca = demo_vars + ['fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']
r_cca_base = run_gls(cca_data, 'ca_gdp', base_vars_cca)
if r_cca_base:
    print(f"\n  Model A (CCA baseline): R²={r_cca_base['r_squared']:.4f}, N={r_cca_base['n_obs']}")
    for v in demo_vars:
        print(f"    {v}: {r_cca_base['coefficients'][v]:.2f} (p={r_cca_base['p_values'][v]:.4f})")

# Model B: + post_opening dummy + Z × post_opening interactions
event_vars = base_vars_cca + ['post_opening'] + post_interactions
r_cca_event = run_gls(cca_data, 'ca_gdp', event_vars)
if r_cca_event:
    print(f"\n  Model B (+ Z×post_opening): R²={r_cca_event['r_squared']:.4f}, N={r_cca_event['n_obs']}")
    print(f"    post_opening level: {r_cca_event['coefficients']['post_opening']:.2f} (p={r_cca_event['p_values']['post_opening']:.4f})")
    for z in demo_vars:
        pre_coef = r_cca_event['coefficients'][z]
        pre_p = r_cca_event['p_values'][z]
        int_coef = r_cca_event['coefficients'][f'{z}_x_post']
        int_p = r_cca_event['p_values'][f'{z}_x_post']
        # Implied post-opening slope is the pre-opening slope plus the interaction.
        post_coef = pre_coef + int_coef
        print(f"    {z} pre-opening:  {pre_coef:8.2f} (p={pre_p:.4f})")
        print(f"    {z} × post:       {int_coef:8.2f} (p={int_p:.4f})")
        print(f"    {z} post-opening: {post_coef:8.2f} (= pre + interaction)")

    # Joint F-test that all Z × post_opening terms are zero, computed from the
    # R² of the unrestricted (Model B) and restricted (no interactions) fits.
    # NOTE(review): this is the OLS-style R² form of the F statistic with
    # n - k - 1 residual df; confirm it is appropriate for PanelGLS's r_squared.
    r_restricted = run_gls(cca_data, 'ca_gdp', base_vars_cca + ['post_opening'])
    if r_restricted:
        q = len(post_interactions)  # number of restrictions tested
        n = r_cca_event['n_obs']
        k = len(event_vars)
        df_resid = n - k - 1
        r2_full = r_cca_event['r_squared']
        r2_restricted = r_restricted['r_squared']
        F = ((r2_full - r2_restricted) / q) / ((1 - r2_full) / df_resid)
        # Survival function: same quantity as 1 - cdf, but accurate in the far tail.
        p_f = stats.f.sf(F, q, df_resid)
        print(f"\n    Joint F-test (Z×post jointly zero): F({q},{df_resid})={F:.3f}, p={p_f:.4f}")


# ═══════════════════════════════════════════════════════════════════════════
# 3. SPLIT SAMPLE: PRE vs POST OPENING
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SPLIT SAMPLE: PRE vs POST OPENING")
print("=" * 70)

# Restrict to CCA countries with an identified opening and split their
# observations by the post_opening flag.
opened_countries = set(opening_years.keys())
is_opened_country = cca_data['iso3'].isin(opened_countries)
pre_data = cca_data[is_opened_country & (cca_data['post_opening'] == 0)].copy()
post_data = cca_data[is_opened_country & (cca_data['post_opening'] == 1)].copy()

print(f"\n  Opened countries: {sorted(opened_countries)}")
print(f"  Pre-opening:  {pre_data['iso3'].nunique()} countries, {len(pre_data)} obs")
print(f"  Post-opening: {post_data['iso3'].nunique()} countries, {len(post_data)} obs")

# Deliberately small specification: demographics plus two controls.
simple_vars = demo_vars + ['fiscal_bal_gdp', 'kaopen']

# Fit both halves before reporting either.
r_pre = run_gls(pre_data, 'ca_gdp', simple_vars)
r_post = run_gls(post_data, 'ca_gdp', simple_vars)


def _report_split(title, res):
    """Print fit statistics and demographic coefficients for one sub-sample."""
    if not res:
        print(f"\n  {title}: insufficient data")
        return
    print(f"\n  {title}: R²={res['r_squared']:.4f}, N={res['n_obs']}, countries={res['n_countries']}")
    for name in demo_vars:
        print(f"    {name}: {res['coefficients'][name]:8.2f} (p={res['p_values'][name]:.4f})")


_report_split("PRE-opening", r_pre)
_report_split("POST-opening", r_post)


# ═══════════════════════════════════════════════════════════════════════════
# 4. BROADER TEST: ALL COUNTRIES WITH OPENING EPISODES
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("BROADER TEST: ALL COUNTRIES WITH KAOPEN OPENING EPISODES")
print("=" * 70)

# Identify opening episodes across ALL countries in sample.
# Definition: KAOPEN rises by more than 1 point between observation i and
# observation i+2, starting from a value below 1. The opening year is the
# middle observation; only the first such episode per country is kept.
# NOTE(review): the window is positional on the non-missing series, so it only
# spans three calendar years when KAOPEN has no gaps — confirm this is intended.
all_openings = {}
for iso in full_sample['iso3'].unique():
    sub = full_sample[full_sample['iso3'] == iso].sort_values('year')
    kaopen_ts = sub.set_index('year')['kaopen'].dropna()
    if len(kaopen_ts) < 5:
        continue

    for i in range(len(kaopen_ts) - 2):
        start = kaopen_ts.iloc[i]
        increase = kaopen_ts.iloc[i + 2] - start
        if increase > 1.0 and start < 1.0:  # Must start from relatively closed
            all_openings[iso] = kaopen_ts.index[i + 1]  # Middle year
            break

print(f"\n  Countries with opening episodes (KAOPEN +1 in 3yr, starting <1): {len(all_openings)}")

# Group: CCA openers vs non-CCA openers. Build the lookup set once instead of
# once per dictionary entry.
cca_set = set(cca_countries)
cca_openers = {k: v for k, v in all_openings.items() if k in cca_set}
non_cca_openers = {k: v for k, v in all_openings.items() if k not in cca_set}

print(f"  CCA openers: {len(cca_openers)} — {sorted(cca_openers.items(), key=lambda x: x[1])}")
print(f"  Non-CCA openers: {len(non_cca_openers)}")

# Show non-CCA openers
for iso, yr in sorted(non_cca_openers.items(), key=lambda x: x[1]):
    sub = full_sample[full_sample['iso3'] == iso].sort_values('year')
    kaopen_ts = sub.set_index('year')['kaopen'].dropna()
    pre = kaopen_ts[kaopen_ts.index < yr].mean() if any(kaopen_ts.index < yr) else np.nan
    post = kaopen_ts[kaopen_ts.index >= yr].mean() if any(kaopen_ts.index >= yr) else np.nan
    print(f"    {iso}: opened ~{yr} (KAOPEN before: {pre:.2f}, after: {post:.2f})")

# Flag openers and their post-opening years on the full sample.
full_sample['post_opening'] = 0.0
full_sample['is_opener'] = 0.0
for iso, yr in all_openings.items():
    mask_post = (full_sample['iso3'] == iso) & (full_sample['year'] >= yr)
    full_sample.loc[mask_post, 'post_opening'] = 1.0
    full_sample.loc[full_sample['iso3'] == iso, 'is_opener'] = 1.0

# Create interaction terms on full sample
opening_interactions = [f'{z}_x_post_opening' for z in demo_vars]
for z, name in zip(demo_vars, opening_interactions):
    full_sample[name] = full_sample[z] * full_sample['post_opening']

# Test on opening countries only
openers_data = full_sample[full_sample['is_opener'] == 1].copy()
print(f"\n  Opener subsample: {openers_data['iso3'].nunique()} countries, {len(openers_data)} obs")

# Model: demographics + controls + post_opening + Z × post_opening
restricted_vars_full = demo_vars + baseline_controls + ['post_opening']
event_vars_full = restricted_vars_full + opening_interactions

r_openers = run_gls(openers_data, 'ca_gdp', event_vars_full)
if r_openers:
    print(f"\n  Event study on all openers: R²={r_openers['r_squared']:.4f}, N={r_openers['n_obs']}")
    print(f"    post_opening level: {r_openers['coefficients']['post_opening']:.2f} (p={r_openers['p_values']['post_opening']:.4f})")
    for z in demo_vars:
        pre_c = r_openers['coefficients'][z]
        pre_p = r_openers['p_values'][z]
        int_c = r_openers['coefficients'][f'{z}_x_post_opening']
        int_p = r_openers['p_values'][f'{z}_x_post_opening']
        print(f"    {z} pre:  {pre_c:8.2f} (p={pre_p:.4f})")
        print(f"    {z} × post: {int_c:8.2f} (p={int_p:.4f})")

    # Joint F-test that all Z × post_opening terms are zero, from the R² of the
    # unrestricted and restricted fits.
    # NOTE(review): OLS-style R² form with n - k - 1 residual df; confirm it is
    # appropriate for PanelGLS's r_squared.
    r_rest = run_gls(openers_data, 'ca_gdp', restricted_vars_full)
    if r_rest:
        q = len(opening_interactions)  # number of restrictions tested
        n = r_openers['n_obs']
        k = len(event_vars_full)
        df_resid = n - k - 1
        r2_full = r_openers['r_squared']
        F = ((r2_full - r_rest['r_squared']) / q) / ((1 - r2_full) / df_resid)
        # Survival function: same quantity as 1 - cdf, but accurate in the far tail.
        p_f = stats.f.sf(F, q, df_resid)
        print(f"    Joint F (Z×post): F({q},{df_resid})={F:.3f}, p={p_f:.4f}")


# ═══════════════════════════════════════════════════════════════════════════
# 5. NON-CCA OPENERS ONLY (PLACEBO/VALIDATION)
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("NON-CCA OPENERS ONLY (VALIDATION)")
print("=" * 70)

# Same event-study specification as the all-openers model, restricted to
# countries with an opening episode that are not in the CCA list.
non_cca_opener_data = full_sample[
    (full_sample['is_opener'] == 1) &
    (~full_sample['iso3'].isin(cca_countries))
].copy()

print(f"  Non-CCA openers: {non_cca_opener_data['iso3'].nunique()} countries, {len(non_cca_opener_data)} obs")

r_non_cca = run_gls(non_cca_opener_data, 'ca_gdp', event_vars_full)
if r_non_cca:
    print(f"\n  Event study on non-CCA openers: R²={r_non_cca['r_squared']:.4f}, N={r_non_cca['n_obs']}")
    print(f"    post_opening level: {r_non_cca['coefficients']['post_opening']:.2f} (p={r_non_cca['p_values']['post_opening']:.4f})")
    for z in demo_vars:
        pre_c = r_non_cca['coefficients'][z]
        pre_p = r_non_cca['p_values'][z]
        int_c = r_non_cca['coefficients'][f'{z}_x_post_opening']
        int_p = r_non_cca['p_values'][f'{z}_x_post_opening']
        print(f"    {z} pre:  {pre_c:8.2f} (p={pre_p:.4f})")
        print(f"    {z} × post: {int_c:8.2f} (p={int_p:.4f})")

    # Joint F-test that all Z × post_opening terms are zero, from the R² of the
    # unrestricted and restricted (no interactions) fits.
    # NOTE(review): OLS-style R² form with n - k - 1 residual df; confirm it is
    # appropriate for PanelGLS's r_squared.
    r_rest2 = run_gls(non_cca_opener_data, 'ca_gdp', demo_vars + baseline_controls + ['post_opening'])
    if r_rest2:
        q = len(demo_vars)  # one interaction term per demographic variable
        n = r_non_cca['n_obs']
        k = len(event_vars_full)
        df_resid = n - k - 1
        r2_full = r_non_cca['r_squared']
        F = ((r2_full - r_rest2['r_squared']) / q) / ((1 - r2_full) / df_resid)
        # Survival function: same quantity as 1 - cdf, but accurate in the far tail.
        p_f = stats.f.sf(F, q, df_resid)
        print(f"    Joint F (Z×post): F({q},{df_resid})={F:.3f}, p={p_f:.4f}")


# ═══════════════════════════════════════════════════════════════════════════
# 6. FULL SAMPLE: Z × KAOPEN × post_opening (TRIPLE INTERACTION)
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("FULL SAMPLE: DOES Z×KAOPEN STRENGTHEN POST-OPENING?")
print("=" * 70)

# Test whether the Z×KAOPEN interaction is stronger after opening.
# The interaction columns are added to full_sample, but the regression itself
# is fitted on the opener subsample only (rows with is_opener == 1).
kaopen_terms = []
triple_terms = []
for z in demo_vars:
    z_times_kaopen = full_sample[z] * full_sample['kaopen']
    full_sample[f'{z}_x_kaopen'] = z_times_kaopen
    full_sample[f'{z}_x_kaopen_x_post'] = z_times_kaopen * full_sample['post_opening']
    kaopen_terms.append(f'{z}_x_kaopen')
    triple_terms.append(f'{z}_x_kaopen_x_post')

# Regressor order: demographics, controls, Z×KAOPEN, post_opening, Z×KAOPEN×post.
triple_vars = demo_vars + baseline_controls + kaopen_terms + ['post_opening'] + triple_terms

# Re-slice the openers so the subsample carries the columns just added.
openers_data = full_sample[full_sample['is_opener'] == 1].copy()
r_triple = run_gls(openers_data, 'ca_gdp', triple_vars)
if r_triple:
    print(f"\n  Triple interaction on openers: R²={r_triple['r_squared']:.4f}, N={r_triple['n_obs']}")
    coefs = r_triple['coefficients']
    pvals = r_triple['p_values']
    for z, base_name, triple_name in zip(demo_vars, kaopen_terms, triple_terms):
        kaopen_c, kaopen_p = coefs[base_name], pvals[base_name]
        triple_c, triple_p = coefs[triple_name], pvals[triple_name]
        print(f"    {z}×KAOPEN (base):      {kaopen_c:8.2f} (p={kaopen_p:.4f})")
        print(f"    {z}×KAOPEN×post:        {triple_c:8.2f} (p={triple_p:.4f})")


# ═══════════════════════════════════════════════════════════════════════════
# 7. SUMMARY TABLE
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SUMMARY: DOES THE DEMOGRAPHIC CHANNEL ACTIVATE POST-OPENING?")
print("=" * 70)

print("""
The causal hypothesis: when a country opens its capital account, the lifecycle
demographic mechanism should begin to operate — demographics should start
predicting current account balances. Before opening, capital cannot flow,
so demographic pressures are trapped domestically.

Test: Z × post_opening interaction, where post_opening = 1 after a country's
capital account liberalization episode.

If causal: Z × post_opening should be significant and have the SAME sign as
the baseline Z coefficients (demographics "turn on" after opening).
""")

# Flatten every fitted model into one long table: one row per
# (model, regressor), with the model-level fit statistics repeated on each row.
# Models that could not be estimated (run_gls returned None) are skipped.
result_columns = ['model', 'variable', 'coefficient', 'std_error', 'p_value',
                  'r_squared', 'n_obs', 'n_countries']
results_rows = []
for label, r in [
    ("CCA_baseline", r_cca_base),
    ("CCA_event_study", r_cca_event),
    ("CCA_pre_opening", r_pre),
    ("CCA_post_opening", r_post),
    ("All_openers_event", r_openers),
    ("Non_CCA_openers_event", r_non_cca),
    ("Triple_interaction", r_triple),
]:
    if r is None:
        continue
    for v, coef in r['coefficients'].items():
        results_rows.append({
            'model': label,
            'variable': v,
            'coefficient': coef,
            'std_error': r['std_errors'][v],
            'p_value': r['p_values'][v],
            'r_squared': r['r_squared'],
            'n_obs': r['n_obs'],
            'n_countries': r['n_countries'],
        })

# Explicit columns keep the CSV header stable even if no model was estimated.
results_df = pd.DataFrame(results_rows, columns=result_columns)

# Create the output directory on demand so a missing folder does not abort
# the save after the whole analysis has already run.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
results_path = OUTPUT_DIR / 'cca_event_study.csv'
results_df.to_csv(results_path, index=False)
print(f"\nSaved to {results_path}")
print("Done.")
